"""
统计函数：均值，方差等
"""
import torch


def test_statistic():
    """Print a random 2x2 tensor followed by several torch reductions of it.

    Each entry is printed as three lines: a caption (the call as written),
    the computed value, and a blank separator. The values change between
    runs because the input comes from ``torch.rand``.
    """
    print("==== statistic ====")
    sample = torch.rand(2, 2)

    # (caption, thunk) pairs emitted in order. Each thunk runs only after its
    # caption has been printed, so nothing is computed ahead of time.
    demos = [
        ("a", lambda: sample),
        ("torch.mean(a, dim=0)", lambda: torch.mean(sample, dim=0)),
        ("torch.sum(a, dim=0)", lambda: torch.sum(sample, dim=0)),
        ("torch.prod(a, dim=0)", lambda: torch.prod(sample, dim=0)),
        # The next three are called without a dim argument.
        ("torch.std(a)", lambda: torch.std(sample)),
        ("torch.var(a)", lambda: torch.var(sample)),
        ("torch.median(a)", lambda: torch.median(sample)),
        # Caption (Chinese) notes that mode defaults to dim=-1, the last dimension.
        ("torch.mode(a), 默认dim=-1即最后一维", lambda: torch.mode(sample)),
    ]

    for caption, compute in demos:
        print(caption)
        print(compute())
        print()


def test_hist():
    """Print histogram demos: ``torch.histc`` on floats, ``torch.bincount`` on ints.

    Each entry is printed as a caption line, the value, and a blank line.
    The inputs are random, so the output changes between runs.
    """
    def emit(caption, value):
        # One demo entry: caption, value, blank separator.
        print(caption)
        print(value)
        print()

    print("==== hist ====")

    floats = torch.rand(2, 2) * 10
    emit("a", floats)
    # Caption (Chinese): passing min=0 and max=0 makes histc use the tensor's
    # own minimum and maximum as the histogram range.
    emit(
        "torch.histc(a, 6, min=0, max=0), min/max取0表示取tensor中的最大最小",
        torch.histc(floats, 6, min=0, max=0),
    )

    ints = torch.randint(0, 10, [10])
    emit("a = torch.randint(0, 10, [10])", ints)
    # Caption (Chinese): bincount only accepts 1-D input.
    emit("torch.bincount(a) 注意：bincount只能处理1维的数据", torch.bincount(ints))


if __name__ == '__main__':
    # Run the demos in order: statistics first, then histograms.
    for demo in (test_statistic, test_hist):
        demo()
